# -*- coding: utf-8 -*-
# @Author: lidongdong
# @time  : 19-1-24 下午9:51
# @file  : data_utils.py


"""utils to load all training data"""

import json

import cv2
import h5py
import numpy as np
import yaml


def load_vocab(vocab_size=10000, vocab_file="../data/vocab.json"):
    """Load the word->index vocabulary, keeping only indices below vocab_size.

    Args:
        vocab_size: entries whose index is strictly less than this are kept.
        vocab_file: path to a JSON file mapping word -> integer index.

    Returns:
        dict mapping word -> index, restricted to indices < vocab_size.
    """
    print("[*]  Loading vocab from {}".format(vocab_file))
    # The vocab file is plain JSON, so parse it with json.load instead of
    # yaml.load: json.load is faster and, unlike bare yaml.load, cannot
    # execute arbitrary constructors on untrusted input.
    with open(vocab_file) as f:
        vocab = json.load(f)
    # Truncate to the requested vocabulary size.
    return {word: idx for word, idx in vocab.items() if idx < vocab_size}


def load_image(image_h5_filename, load_to_memory=False):
    """Open an HDF5 image file and return its "images" dataset.

    Args:
        image_h5_filename: path to an HDF5 file containing an "images"
            dataset laid out as (sample_num, height, width, channel).
        load_to_memory: if True, read the whole dataset into an in-memory
            array; otherwise return the lazy h5py dataset.

    Returns:
        The h5py "images" dataset (lazy; the file stays open), or an
        in-memory array when load_to_memory is True.
    """
    print("[*] Loading data from {}".format(image_h5_filename))
    # Open read-only: relying on the default mode is deprecated in h5py.
    # NOTE(review): the file handle is deliberately left open when the lazy
    # dataset is returned -- closing it would invalidate the dataset.
    images = h5py.File(image_h5_filename, "r")["images"]
    print("done")
    if load_to_memory:
        images = images[:, :, :, :]     # sample_num, height, width, channel
    return images


def load_caption(caption_json_filename):
    """Read a caption JSON file and return the value of its "content" key."""
    with open(caption_json_filename) as caption_file:
        captions = json.load(caption_file)
    return captions["content"]


def process_images(images, crop_scale_no=2):
    """Crop or resize a batch of 256x256 images to 224x224, then rescale
    pixel values from [0, 255] into [-1, 1].

    Args:
        images: array of shape (sample_num, height, width, channel);
            the source images are expected to be 256x256.
        crop_scale_no: 0 -> center-crop to 224x224;
                       1 -> resize each image to 224x224;
                       anything else -> leave the spatial size unchanged.

    Returns:
        Float array with values in [-1, 1].
    """
    if crop_scale_no == 0:
        # Center crop: 256 - 2*16 = 224 on each spatial dimension.
        images = images[:, 16: 240, 16: 240, :]
    elif crop_scale_no == 1:
        # The original called the nonexistent cv2.reshape; the intended API
        # is cv2.resize, which operates on one image at a time, so resize
        # each sample and re-stack into a batch.
        images = np.stack([cv2.resize(img, (224, 224)) for img in images])

    # Map [0, 255] -> [-1, 1].
    return images / 127.5 - 1.



